
[PyTorch] Add dtype information to QuantizedTensorStorage class #2676

Open

ptrendx wants to merge 5 commits into NVIDIA:main from ptrendx:pr_dtype_in_storage

Conversation


@ptrendx ptrendx commented Feb 12, 2026

Description

This PR adds the fake dtype information to the QuantizedTensorStorage class. This eliminates the need to guess the correct dtype for dequantization, as was previously done in distributed.py, and it eliminates the unintentional dequantization to FP32 when calling dequantize() on the Storage class with no dtype argument.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added the _dtype field to the QuantizedTensorStorage class
  • Modified the dequantize call to use that new field when calling dequantize with no arguments
  • Removed guessing of the dtype from distributed.py
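The effect of the changes above on default dequantization can be sketched in a minimal, self-contained illustration. The class below is a hypothetical stand-in (the real QuantizedTensorStorage holds quantized payloads, scales, and a quantizer), assuming only that `_dtype` stores the fake high-precision dtype and that `dequantize()` falls back to it:

```python
import torch

class QuantizedTensorStorageSketch:
    """Minimal stand-in for QuantizedTensorStorage; names are illustrative."""

    def __init__(self, data: torch.Tensor, fake_dtype: torch.dtype):
        self._data = data          # stands in for the quantized payload
        self._dtype = fake_dtype   # the new field added by this PR

    def dequantize(self, *, dtype=None):
        # Before this PR the implicit fallback was FP32; now the stored
        # "fake" high-precision dtype is used when no dtype is given.
        if dtype is None:
            dtype = self._dtype
        return self._data.to(dtype)

storage = QuantizedTensorStorageSketch(torch.zeros(4), fake_dtype=torch.bfloat16)
print(storage.dequantize().dtype)  # torch.bfloat16 rather than torch.float32
```

An explicit dtype argument still wins over the stored field, so existing call sites that pass a dtype are unaffected.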

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx ptrendx requested a review from timmoon10 February 12, 2026 19:07

greptile-apps bot commented Feb 12, 2026

Greptile Summary

This PR adds a _dtype field (the "fake" high-precision dtype) to every QuantizedTensorStorage subclass and threads it through constructors, get_metadata(), and dequantize() defaults. This cleanly removes the hardcoded torch.bfloat16 guesses in distributed.py and ensures that dequantize() called without arguments no longer silently degrades to FP32.

Key changes and concerns:

  • fake_dtype validation gate in QuantizedTensor.__new__ — The guard if fake_dtype is not None and fake_dtype != dtype: raise ValueError(...) (line 376–377 of quantized_tensor.py) is logically correct in isolation, but since get_metadata() now always emits fake_dtype=self._dtype, any code that calls make_like(tensor, dtype=new_dtype) with a different dtype will immediately hit this ValueError. There are 10+ existing call sites in the codebase (attention/dot_product_attention/context_parallel.py, utils.py, tensor/__init__.py) that pass an explicit dtype to make_like, and the module-level cast_to_dtype helper (used by model.half() / model.bfloat16()) iterates over all QuantizedTensor parameters using exactly this pattern. This is a regression that will surface in normal training workflows.
  • No backward-compatibility guard for old pickled storage objects — The new _dtype field is only set in __new__; there is no hasattr fallback in dequantize() (unlike the dtype property on QuantizedTensor). Unpickling a *TensorStorage that was saved before this PR will result in AttributeError: _dtype on the first dequantize call.
  • Storage classes correctly updated — All four storage classes (Float8TensorStorage, Float8BlockwiseQTensorStorage, MXFP8TensorStorage, NVFP4TensorStorage) follow a consistent pattern and correctly dispatch between storage-only (object.__new__) and full-tensor (super().__new__) paths.

Confidence Score: 2/5

  • Not safe to merge as-is — the fake_dtype validation in QuantizedTensor.__new__ will break existing make_like(dtype=X) call sites throughout the attention module and the module-level dtype-cast utility.
  • The core idea (storing the high-precision dtype on storage objects) is sound and the distributed.py and C++ changes are clean improvements. However, the new validation guard in QuantizedTensor.__new__ combined with fake_dtype being unconditionally included in get_metadata() creates a regression: any call to make_like or to_dtype that changes the nominal dtype — including the widely-used cast_to_dtype module helper — will now raise a ValueError. This is a blocking correctness issue for normal training workflows.
  • transformer_engine/pytorch/quantized_tensor.py — the fake_dtype validation logic and its interaction with get_metadata() requires the most attention before merging.

Important Files Changed

Filename Overview
transformer_engine/pytorch/quantized_tensor.py Adds _dtype annotation to QuantizedTensorStorage and a fake_dtype parameter to QuantizedTensor.__new__. The new validation guard (fake_dtype != dtype → ValueError) is correct in isolation but breaks all make_like(dtype=X) call sites where X differs from the tensor's current _dtype, because get_metadata() now injects fake_dtype matching the old dtype.
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py Adds fake_dtype parameter to __new__, propagates it via super().__new__ for the full-tensor path, stores it as _dtype for the storage-only path, and adds it to get_metadata(). Logic is correct; view() and get_metadata() now preserve _dtype correctly.
transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py Previously always called super().__new__() regardless of cls, corrected to the same storage-vs-full-tensor dispatch pattern as other storage classes. fake_dtype propagated correctly throughout.
transformer_engine/pytorch/distributed.py Removes three hardcoded dtype = torch.bfloat16 guesses and replaces them with dtype = inp._dtype, which is now reliably set on all storage objects produced by this PR. Clean improvement.
transformer_engine/pytorch/csrc/quantizer.cpp Adds kwargs["fake_dtype"] = GetATenDType(dtype) to all five create_tensor call sites in the internal (storage-only) C++ path. The full-tensor path already passes dtype as a top-level constructor arg so no change needed there.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["C++ Quantizer::create_tensor(dtype)"] -->|"internal=true"| B["Float8TensorStorage.__new__\n(fake_dtype=GetATenDType(dtype))"]
    A -->|"internal=false"| C["Float8Tensor.__new__\n(dtype=GetATenDType(dtype))"]

    B -->|"cls is Storage"| D["object.__new__()\ninstance._dtype = fake_dtype"]
    B -->|"cls is Tensor subclass"| E["super().__new__(cls, fake_dtype=fake_dtype)"]
    C --> E

    E --> F["QuantizedTensor.__new__(dtype, fake_dtype)\nValidate: fake_dtype == dtype if not None\ninstance._dtype = dtype"]

    F --> G["QuantizedTensor._dtype set"]
    D --> G

    G -->|"dequantize() called"| H{"dtype arg?"}
    H -->|"None"| I["use self._dtype"]
    H -->|"explicit"| J["use explicit dtype"]
    I --> K["dequantize to correct high-precision dtype"]
    J --> K

    style D fill:#90EE90
    style G fill:#90EE90
    style F fill:#FFB6C1,stroke:#FF0000
    note1["⚠️ Validation fails for\nmake_like(dtype=X) when X != _dtype"]
    F -.-> note1

Last reviewed commit: 369f8b5

@greptile-apps greptile-apps bot left a comment

13 files reviewed, no comments


ptrendx commented Feb 12, 2026

/te-ci pytorch

ksivaman previously approved these changes Feb 12, 2026

@ksivaman ksivaman left a comment
LGTM

timmoon10 previously approved these changes Feb 14, 2026

@timmoon10 timmoon10 left a comment
Overall this is a big improvement. I have some naming nits.

shape: Iterable[int],
dtype: torch.dtype,
*,
fake_dtype: Optional[torch.dtype] = None,
Collaborator
Isn't this redundant with the dtype kwarg?

Member Author
This is mostly to avoid issues with MRO and still have fairly straightforward constructors for the Storage classes.

Member Author
Also, I just noticed that the make_like call would otherwise be problematic there: we want to include the fake_dtype in the get_metadata call, but if it were named dtype it would clash with the dtype that we pass directly in make_like.
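The keyword clash described above can be reproduced in plain Python. The `Dummy` class and string values below are hypothetical stand-ins for a QuantizedTensor constructor and torch dtypes:

```python
class Dummy:
    """Hypothetical constructor taking both a dtype and a fake_dtype kwarg."""

    def __init__(self, *, dtype=None, fake_dtype=None):
        self.dtype = dtype
        self.fake_dtype = fake_dtype

# Metadata harvested from an existing tensor, as get_metadata would emit it.
metadata = {"fake_dtype": "bfloat16"}

# make_like-style call: an explicit dtype coexists with the metadata kwargs.
obj = Dummy(dtype="float16", **metadata)  # fine

# Had the metadata key been named "dtype", the same call pattern clashes:
try:
    Dummy(dtype="float16", **{"dtype": "bfloat16"})
except TypeError as exc:
    print(exc)  # got multiple values for keyword argument 'dtype'
```

This is why keeping a distinct name in the metadata dict avoids a TypeError at every make_like call site that passes dtype explicitly.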

data: Optional[torch.Tensor],
fp8_scale_inv: torch.Tensor,
fp8_dtype: TE_DType,
fake_dtype: Optional[torch.dtype] = None,
Collaborator
I'd prefer to just name it dtype since QuantizedTensor is already using that name in its constructor.

Suggested change
fake_dtype: Optional[torch.dtype] = None,
dtype: Optional[torch.dtype] = None,

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx ptrendx dismissed stale reviews from timmoon10 and ksivaman via be723b2 February 18, 2026 01:40
@greptile-apps greptile-apps bot left a comment

13 files reviewed, no comments

timmoon10 previously approved these changes Feb 19, 2026

@timmoon10 timmoon10 left a comment
Still not a fan of fake_dtype, but approving to unblock.


ptrendx commented Feb 24, 2026

/te-ci pytorch

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

ptrendx commented Mar 4, 2026

/te-ci pytorch

Comment on lines +376 to +377
if fake_dtype is not None and fake_dtype != dtype:
raise ValueError(f"fake_dtype ({fake_dtype}) does not match dtype ({dtype})")
Contributor
Validation breaks existing make_like call sites

This new guard will cause regressions on every call to make_like(tensor, dtype=X) where X differs from tensor._dtype, because get_metadata() now always injects fake_dtype=self._dtype into kwargs, and QuantizedTensor.__new__ is then called with both dtype=X (the intended new dtype) and fake_dtype=old_dtype (from metadata).

Confirmed breakage paths:

  • transformer_engine/pytorch/tensor/__init__.py:63 — module cast utility (model.half(), model.bfloat16(), etc.) calls tensor.__class__.make_like(tensor, dtype=dtype) for every QuantizedTensor; whenever dtype != tensor._dtype the model cast will raise ValueError.
  • attention/dot_product_attention/context_parallel.py — 10+ call sites of the form Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) where fwd_nominal_dtype may differ from x._dtype.
  • attention/dot_product_attention/utils.py:2220 — same pattern.

The root cause is that fake_dtype is being included in get_metadata() but the constructor-level guard then rejects any case where the caller wants to create a clone at a different nominal dtype. Either:

  1. Remove the guard (it is redundant for the full-tensor path, because QuantizedTensor.__new__ already sets _dtype = dtype), or
  2. Override fake_dtype in QuantizedTensor.make_like so it matches the requested dtype before calling the constructor.
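Option 2 could look roughly like the following self-contained sketch. The class, string dtype values, and method bodies are hypothetical simplifications (not the actual Transformer Engine code), assuming only the guard quoted above and a get_metadata() that always emits fake_dtype:

```python
class QuantizedTensorSketch:
    """Minimal stand-in reproducing the guard and the proposed fix."""

    def __init__(self, *, dtype, fake_dtype=None):
        # The guard added at quantized_tensor.py lines 376-377.
        if fake_dtype is not None and fake_dtype != dtype:
            raise ValueError(f"fake_dtype ({fake_dtype}) does not match dtype ({dtype})")
        self._dtype = dtype

    def get_metadata(self):
        # Mirrors the PR: fake_dtype is always emitted from metadata.
        return {"fake_dtype": self._dtype}

    @classmethod
    def make_like(cls, tensor, *, dtype=None, **kwargs):
        ctor_kwargs = tensor.get_metadata()
        if dtype is None:
            dtype = tensor._dtype
        # Fix option 2: override the harvested fake_dtype so it always
        # matches the requested dtype before the guard runs.
        ctor_kwargs["fake_dtype"] = dtype
        ctor_kwargs.update(kwargs)
        return cls(dtype=dtype, **ctor_kwargs)

t = QuantizedTensorSketch(dtype="bfloat16")
clone = QuantizedTensorSketch.make_like(t, dtype="float16")  # no ValueError
```

Without the override line, the harvested fake_dtype="bfloat16" would reach the constructor alongside dtype="float16" and trip the guard, which is exactly the regression described above.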

Comment on lines +40 to 41
_dtype: torch.dtype
_quantizer: Optional[Quantizer]
Contributor
No lazy-init guard for _dtype on storage objects

QuantizedTensor.dtype (line 405–409) has a hasattr(self, "_dtype") lazy-initializer that protects against deserialization from pre-PR checkpoints. QuantizedTensorStorage and its subclasses have no equivalent protection — _dtype: torch.dtype is only a class-level annotation, not a default value.

If an *TensorStorage object is unpickled from a checkpoint that was saved before this PR, the first call to .dequantize() (or the distributed-ops in distributed.py that now access inp._dtype) will raise AttributeError: _dtype.

Consider adding a similar lazy fallback in the dequantize methods, e.g.:

if dtype is None:
    dtype = getattr(self, "_dtype", torch.float32)
